{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a0e1dd3d",
   "metadata": {},
   "source": [
    "# A walk through latent space with Stable Diffusion\n",
    "Latent space walking, or latent space exploration, is the process of sampling a point in latent space and incrementally changing the latent representation. Its most common application is generating animations where each sampled point is fed to the decoder and is stored as a frame in the final animation. For high-quality latent representations, this produces coherent-looking animations. These animations can provide insight into the feature map of the latent space, and can ultimately lead to improvements in the training process.\n",
    "\n",
    "### References\n",
    "- https://keras.io/examples/generative/random_walks_with_stable_diffusion/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d59a04f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import inspect\n",
    "import logging\n",
    "import math\n",
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from diffusers import (\n",
    "    StableDiffusionPipeline,\n",
    "    AutoencoderKL,\n",
    "    UNet2DConditionModel,\n",
    "    DDIMScheduler\n",
    ")\n",
    "from IPython.display import Image as IImage\n",
    "from PIL import Image\n",
    "from tqdm.auto import tqdm\n",
    "from transformers import CLIPTextModel, CLIPTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6e8c9da",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cuda\"\n",
    "model_path = \"CompVis/stable-diffusion-v1-4\" # you can download the model weights and save locally\n",
    "model_path = \"/data2/hy/model_weights/stable-diffusion-v1-5/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fdd42e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# add new method to original StableDiffusionPipeline\n",
    "class SDPipeline(StableDiffusionPipeline):\n",
    "    \n",
    "    @torch.no_grad()\n",
    "    def encode_text(self, prompt):\n",
    "        \"\"\"Encode `prompt` with the pipeline's tokenizer and text encoder.\n",
    "\n",
    "        Args:\n",
    "            prompt: Text handed straight to `self.tokenizer` (this notebook passes one string).\n",
    "\n",
    "        Returns:\n",
    "            The first output of `self.text_encoder` for the padded token ids, computed on\n",
    "            `self.device` without gradient tracking.\n",
    "        \"\"\"\n",
    "        max_length = self.tokenizer.model_max_length\n",
    "        # Pad up to the model's maximum length. Truncation is done by hand below so that the\n",
    "        # dropped part of the prompt can be reported.\n",
    "        text_inputs = self.tokenizer(\n",
    "            prompt,\n",
    "            padding=\"max_length\",\n",
    "            max_length=max_length,\n",
    "            return_tensors=\"pt\",\n",
    "        )\n",
    "        text_input_ids = text_inputs.input_ids\n",
    "\n",
    "        if text_input_ids.shape[-1] > max_length:\n",
    "            removed_text = self.tokenizer.batch_decode(text_input_ids[:, max_length:])\n",
    "            # No `logger` global exists in this notebook, so fetch a standard-library logger here.\n",
    "            logging.getLogger(__name__).warning(\n",
    "                \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n",
    "                f\" {max_length} tokens: {removed_text}\"\n",
    "            )\n",
    "            text_input_ids = text_input_ids[:, :max_length]\n",
    "        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]\n",
    "        return text_embeddings\n",
    "    \n",
    "    @torch.no_grad()\n",
    "    def generate_image(\n",
    "        self,\n",
    "        text_embeddings,\n",
    "        height=512,\n",
    "        width=512,\n",
    "        num_inference_steps=50,\n",
    "        guidance_scale=7.5,\n",
    "        eta=0.0,\n",
    "        generator=None,\n",
    "        latents=None,\n",
    "        output_type=\"pil\"\n",
    "    ):\n",
    "        \"\"\"Run the denoising loop on precomputed text embeddings and decode the result.\n",
    "\n",
    "        Args:\n",
    "            text_embeddings: 3-D tensor whose first dimension is the batch size, e.g. the\n",
    "                output of `encode_text` or an interpolation of such outputs.\n",
    "            height: Requested image height in pixels; must be divisible by 8.\n",
    "            width: Requested image width in pixels; must be divisible by 8.\n",
    "            num_inference_steps: Number of timesteps handed to the scheduler.\n",
    "            guidance_scale: Classifier-free guidance weight; guidance is skipped unless it is > 1.\n",
    "            eta: Forwarded to `scheduler.step` only if that method accepts an `eta` argument.\n",
    "            generator: Optional `torch.Generator` used when the start noise is sampled here.\n",
    "            latents: Optional 4-D start noise. A batch dimension of 1 is repeated for every\n",
    "                embedding in the batch; otherwise it has to equal the batch size.\n",
    "            output_type: `pil` returns the result of `numpy_to_pil`; any other value returns\n",
    "                the float32 numpy array permuted to (batch, H, W, channels).\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If `height`/`width` are not divisible by 8, or `latents` has an\n",
    "                unexpected number of dimensions or an incompatible batch size.\n",
    "        \"\"\"\n",
    "        if height % 8 != 0 or width % 8 != 0:\n",
    "            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n",
    "\n",
    "        # unpacking also enforces that the embeddings are 3-D\n",
    "        batch_size, _, _ = text_embeddings.shape\n",
    "\n",
    "        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n",
    "        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`\n",
    "        # corresponds to doing no classifier free guidance.\n",
    "        do_classifier_free_guidance = guidance_scale > 1.0\n",
    "        if do_classifier_free_guidance:\n",
    "            # unconditional embeddings: one empty prompt per batch entry\n",
    "            uncond_input = self.tokenizer(\n",
    "                [\"\"] * batch_size,\n",
    "                padding=\"max_length\",\n",
    "                max_length=self.tokenizer.model_max_length,\n",
    "                truncation=True,\n",
    "                return_tensors=\"pt\",\n",
    "            )\n",
    "            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]\n",
    "\n",
    "            # Classifier free guidance needs an unconditional and a conditional prediction.\n",
    "            # Both embedding sets are concatenated into one batch so that a single UNet\n",
    "            # forward pass per timestep produces both.\n",
    "            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])\n",
    "\n",
    "        # get the initial random noise unless the user supplied it\n",
    "        latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8)\n",
    "        latents_dtype = text_embeddings.dtype\n",
    "        if latents is None:\n",
    "            if self.device.type == \"mps\":\n",
    "                # sample on the CPU and move the result, as the original code did for mps\n",
    "                latents = torch.randn(latents_shape, generator=generator, device=\"cpu\", dtype=latents_dtype).to(\n",
    "                    self.device\n",
    "                )\n",
    "            else:\n",
    "                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)\n",
    "        else:\n",
    "            if latents.dim() != len(latents_shape):\n",
    "                raise ValueError(f\"Unexpected latents dimension, got {latents.shape}, expected {latents_shape}\")\n",
    "            if latents.shape[0] != batch_size:\n",
    "                # only a single noise sample can be shared across the batch\n",
    "                if latents.shape[0] != 1:\n",
    "                    raise ValueError(\n",
    "                        f\"Unexpected latents batch size, got {latents.shape[0]}, expected 1 or {batch_size}\"\n",
    "                    )\n",
    "                latents = latents.repeat(batch_size, 1, 1, 1)\n",
    "            # match the device and dtype of the embeddings fed to the UNet\n",
    "            latents = latents.to(self.device, dtype=latents_dtype)\n",
    "\n",
    "        # set timesteps\n",
    "        self.scheduler.set_timesteps(num_inference_steps)\n",
    "\n",
    "        # move all timesteps to the target device once, before the loop\n",
    "        timesteps_tensor = self.scheduler.timesteps.to(self.device)\n",
    "\n",
    "        # scale the initial noise by the standard deviation required by the scheduler\n",
    "        latents = latents * self.scheduler.init_noise_sigma\n",
    "\n",
    "        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n",
    "        # eta (η) corresponds to η in the DDIM paper: https://arxiv.org/abs/2010.02502\n",
    "        # and should be between [0, 1]; it is only passed on when `step` accepts it\n",
    "        accepts_eta = \"eta\" in inspect.signature(self.scheduler.step).parameters\n",
    "        extra_step_kwargs = {}\n",
    "        if accepts_eta:\n",
    "            extra_step_kwargs[\"eta\"] = eta\n",
    "\n",
    "        for t in self.progress_bar(timesteps_tensor):\n",
    "            # expand the latents if we are doing classifier free guidance\n",
    "            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n",
    "            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n",
    "\n",
    "            # predict the noise residual\n",
    "            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample\n",
    "\n",
    "            # perform guidance\n",
    "            if do_classifier_free_guidance:\n",
    "                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n",
    "                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n",
    "\n",
    "            # compute the previous noisy sample x_t -> x_t-1\n",
    "            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample\n",
    "\n",
    "        # rescale the latents before decoding; 0.18215 is presumably the Stable Diffusion v1\n",
    "        # VAE scaling factor - confirm before using a different checkpoint\n",
    "        latents = 1 / 0.18215 * latents\n",
    "        image = self.vae.decode(latents).sample\n",
    "\n",
    "        # map the decoder output from [-1, 1] to [0, 1]\n",
    "        image = (image / 2 + 0.5).clamp(0, 1)\n",
    "\n",
    "        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n",
    "        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n",
    "\n",
    "        if output_type == \"pil\":\n",
    "            image = self.numpy_to_pil(image)\n",
    "\n",
    "        return image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76eda94d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define noise scheduler: the parameters must match the original stable diffusion\n",
    "noise_scheduler = DDIMScheduler(\n",
    "    num_train_timesteps=1000,\n",
    "    beta_start=0.00085,\n",
    "    beta_end=0.012,\n",
    "    beta_schedule=\"scaled_linear\",\n",
    "    clip_sample=False, # don't clip the sample: x0 in stable diffusion is not in the range [-1, 1]\n",
    "    set_alpha_to_one=False,\n",
    "    steps_offset=1,\n",
    ")\n",
    "# Load the Stable Diffusion pipeline and hand it the DDIM scheduler defined above.\n",
    "# Without the explicit `scheduler=` argument the scheduler stored with the checkpoint\n",
    "# is used and `noise_scheduler` has no effect.\n",
    "pipe = SDPipeline.from_pretrained(\n",
    "    model_path,\n",
    "    scheduler=noise_scheduler,\n",
    "    torch_dtype=torch.float16,\n",
    ").to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bed43749",
   "metadata": {},
   "outputs": [],
   "source": [
    "def image_grid(imgs, rows, cols):\n",
    "    \"\"\"Paste `imgs` row by row into one `rows` x `cols` RGB image.\n",
    "\n",
    "    Every grid cell has the size of the first image, and `len(imgs)` must equal\n",
    "    `rows * cols`.\n",
    "    \"\"\"\n",
    "    assert len(imgs) == rows*cols, f\"expected {rows*cols} images for a {rows}x{cols} grid, got {len(imgs)}\"\n",
    "\n",
    "    w, h = imgs[0].size\n",
    "    grid = Image.new('RGB', size=(cols*w, rows*h))\n",
    "\n",
    "    for i, img in enumerate(imgs):\n",
    "        # image i goes to row i // cols, column i % cols\n",
    "        row, col = divmod(i, cols)\n",
    "        grid.paste(img, box=(col*w, row*h))\n",
    "    return grid"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b27bb82",
   "metadata": {},
   "source": [
    "## Interpolating between text prompts\n",
    "\n",
    "In Stable Diffusion, a text prompt is first encoded into a vector, and that encoding is used to guide the diffusion process. The latent encoding vector has shape 77x768 (that's huge!), and when we give Stable Diffusion a text prompt, we're generating images from just one such point on the latent manifold.\n",
    "\n",
    "To explore more of this manifold, we can interpolate between two text encodings and generate images at those interpolated points:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21082172",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Endpoints of the interpolation and the sampling settings.\n",
    "prompt_1 = \"A watercolor painting of a Golden Retriever at the beach\"\n",
    "prompt_2 = \"A still life DSLR photo of a bowl of fruit\"\n",
    "interpolation_steps = 8\n",
    "height, width = 512, 512\n",
    "num_inference_steps = 25\n",
    "frames_per_second = 2\n",
    "\n",
    "encoding_1 = pipe.encode_text(prompt_1)\n",
    "encoding_2 = pipe.encode_text(prompt_2)\n",
    "\n",
    "# One blended encoding per weight, from pure prompt_1 (0.0) to pure prompt_2 (1.0).\n",
    "lerp_weights = np.linspace(0., 1., interpolation_steps)\n",
    "interpolated_encodings = torch.cat(\n",
    "    [torch.lerp(encoding_1, encoding_2, w) for w in lerp_weights],\n",
    "    dim=0,\n",
    ")\n",
    "\n",
    "# A single seeded noise sample, so that every generated image starts from the same noise.\n",
    "generator = torch.Generator(device).manual_seed(12345)\n",
    "latents_shape = (1, pipe.unet.in_channels, height // 8, width // 8)\n",
    "latents = torch.randn(\n",
    "    latents_shape,\n",
    "    generator=generator,\n",
    "    device=device,\n",
    "    dtype=encoding_1.dtype,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "603046aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render every interpolated encoding in one batch, all from the shared start noise.\n",
    "image = pipe.generate_image(\n",
    "    interpolated_encodings,\n",
    "    height=height,\n",
    "    width=width,\n",
    "    num_inference_steps=num_inference_steps,\n",
    "    latents=latents,\n",
    ")\n",
    "\n",
    "# Store the frames as an animated GIF: the first frame is saved and the rest appended.\n",
    "first_frame, *remaining_frames = image\n",
    "first_frame.save(\n",
    "    \"doggo-and-fruit-8.gif\",\n",
    "    save_all=True,\n",
    "    append_images=remaining_frames,\n",
    "    duration=1000 // frames_per_second,\n",
    "    loop=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "499c877b",
   "metadata": {},
   "outputs": [],
   "source": [
    "IImage(\"doggo-and-fruit-8.gif\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aadb104c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "grid = image_grid(image, 1, len(image))\n",
    "grid"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc2773e3",
   "metadata": {},
   "source": [
    "To best visualize this, we should do a much more fine-grained interpolation, using hundreds of steps. In order to keep batch size small (so that we don't OOM our GPU), this requires manually batching our interpolated encodings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b411252",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same two prompts as above, now sampled far more densely and rendered in batches.\n",
    "prompt_1 = \"A watercolor painting of a Golden Retriever at the beach\"\n",
    "prompt_2 = \"A still life DSLR photo of a bowl of fruit\"\n",
    "interpolation_steps = 128\n",
    "height, width = 512, 512\n",
    "batch_size = 8\n",
    "num_inference_steps = 25\n",
    "frames_per_second = 8\n",
    "\n",
    "encoding_1 = pipe.encode_text(prompt_1)\n",
    "encoding_2 = pipe.encode_text(prompt_2)\n",
    "\n",
    "# One blended encoding per weight, from pure prompt_1 (0.0) to pure prompt_2 (1.0).\n",
    "lerp_weights = np.linspace(0., 1., interpolation_steps)\n",
    "interpolated_encodings = torch.cat(\n",
    "    [torch.lerp(encoding_1, encoding_2, w) for w in lerp_weights],\n",
    "    dim=0,\n",
    ")\n",
    "\n",
    "# A single seeded noise sample shared by all frames.\n",
    "generator = torch.Generator(device).manual_seed(12345)\n",
    "latents_shape = (1, pipe.unet.in_channels, height // 8, width // 8)\n",
    "latents = torch.randn(\n",
    "    latents_shape,\n",
    "    generator=generator,\n",
    "    device=device,\n",
    "    dtype=encoding_1.dtype,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "126d002d",
   "metadata": {},
   "outputs": [],
   "source": [
    "generated_images = []\n",
    "# Walk over batch start offsets so that a final, smaller batch is still rendered when\n",
    "# the number of encodings is not a multiple of `batch_size`.\n",
    "for start in range(0, interpolated_encodings.shape[0], batch_size):\n",
    "    batch_images = pipe.generate_image(\n",
    "        interpolated_encodings[start:start + batch_size],\n",
    "        height=height,\n",
    "        width=width,\n",
    "        num_inference_steps=num_inference_steps,\n",
    "        latents=latents\n",
    "    )\n",
    "    generated_images.extend(batch_images)\n",
    "\n",
    "# Save the first frame and append the remaining ones to form the animation.\n",
    "generated_images[0].save(\n",
    "    \"doggo-and-fruit-128.gif\",\n",
    "    save_all=True,\n",
    "    append_images=generated_images[1:],\n",
    "    duration=1000 // frames_per_second,\n",
    "    loop=0\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0153a0b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "IImage(\"doggo-and-fruit-128.gif\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e977306d",
   "metadata": {},
   "source": [
    "We can even extend this concept for more than one image. For example, we can interpolate between four prompts:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8958eadd",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_1 = \"A watercolor painting of a Golden Retriever at the beach\"\n",
    "prompt_2 = \"A still life DSLR photo of a bowl of fruit\"\n",
    "prompt_3 = \"The eiffel tower in the style of starry night\"\n",
    "prompt_4 = \"An architectural sketch of a skyscraper\"\n",
    "\n",
    "height, width = 512, 512\n",
    "num_inference_steps = 25\n",
    "\n",
    "interpolation_steps = 6\n",
    "batch_size = 4\n",
    "\n",
    "# Encode the four prompts in order.\n",
    "encoding_1, encoding_2, encoding_3, encoding_4 = (\n",
    "    pipe.encode_text(text) for text in (prompt_1, prompt_2, prompt_3, prompt_4)\n",
    ")\n",
    "\n",
    "lerp_weights = np.linspace(0., 1., interpolation_steps)\n",
    "\n",
    "# Inner sweeps: prompt 1 -> prompt 2 and prompt 3 -> prompt 4.\n",
    "interpolated_encodings_12 = torch.cat(\n",
    "    [torch.lerp(encoding_1, encoding_2, w) for w in lerp_weights], dim=0\n",
    ")\n",
    "interpolated_encodings_34 = torch.cat(\n",
    "    [torch.lerp(encoding_3, encoding_4, w) for w in lerp_weights], dim=0\n",
    ")\n",
    "# Outer sweep: for every weight, blend the whole 1->2 sweep with the whole 3->4 sweep,\n",
    "# which yields interpolation_steps**2 encodings in total.\n",
    "interpolated_encodings = torch.cat(\n",
    "    [torch.lerp(interpolated_encodings_12, interpolated_encodings_34, w) for w in lerp_weights],\n",
    "    dim=0,\n",
    ")\n",
    "\n",
    "# A single seeded noise sample shared by all grid cells.\n",
    "generator = torch.Generator(device).manual_seed(12345)\n",
    "latents_shape = (1, pipe.unet.in_channels, height // 8, width // 8)\n",
    "latents = torch.randn(\n",
    "    latents_shape,\n",
    "    generator=generator,\n",
    "    device=device,\n",
    "    dtype=encoding_1.dtype,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fe81192",
   "metadata": {},
   "outputs": [],
   "source": [
    "generated_images = []\n",
    "# Walk over batch start offsets so that a final, smaller batch is still rendered when\n",
    "# the number of encodings is not a multiple of `batch_size`.\n",
    "for start in range(0, interpolated_encodings.shape[0], batch_size):\n",
    "    batch_images = pipe.generate_image(\n",
    "        interpolated_encodings[start:start + batch_size],\n",
    "        height=height,\n",
    "        width=width,\n",
    "        num_inference_steps=num_inference_steps,\n",
    "        latents=latents\n",
    "    )\n",
    "    generated_images.extend(batch_images)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c5f4fbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "grid = image_grid(generated_images, interpolation_steps, interpolation_steps)\n",
    "grid"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e28c3216",
   "metadata": {},
   "source": [
    "We can also interpolate while allowing diffusion noise to vary by dropping the `latents` parameter:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "499ca272",
   "metadata": {},
   "outputs": [],
   "source": [
    "generated_images = []\n",
    "# No `latents` argument here: generate_image samples fresh noise for every batch.\n",
    "# Walking over start offsets keeps a final, smaller batch from being skipped.\n",
    "for start in range(0, interpolated_encodings.shape[0], batch_size):\n",
    "    batch_images = pipe.generate_image(\n",
    "        interpolated_encodings[start:start + batch_size],\n",
    "        height=height,\n",
    "        width=width,\n",
    "        num_inference_steps=num_inference_steps,\n",
    "    )\n",
    "    generated_images.extend(batch_images)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ead171d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "grid = image_grid(generated_images, interpolation_steps, interpolation_steps)\n",
    "grid"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe1ff149",
   "metadata": {},
   "source": [
    "## A walk around a text prompt\n",
    "Our next experiment will be to go for a walk around the latent manifold starting from a point produced by a particular prompt."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e87b914d",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"The Eiffel Tower in the style of starry night\"\n",
    "\n",
    "height = 512\n",
    "width = 512\n",
    "num_inference_steps = 25\n",
    "\n",
    "walk_steps = 128\n",
    "step_size = 0.001\n",
    "batch_size = 8\n",
    "\n",
    "encoding = pipe.encode_text(prompt)\n",
    "\n",
    "# Offset of every walk step from the prompt encoding: 0, step_size, 2 * step_size, ...\n",
    "# The sum is formed once in float32 and cast back, because repeatedly adding such a small\n",
    "# step to a float16 tensor (the dtype the pipeline is loaded with) can round the\n",
    "# increment away or inflate it.\n",
    "walk_offsets = torch.arange(walk_steps, device=encoding.device, dtype=torch.float32) * step_size\n",
    "walked_encodings = (encoding.float() + walk_offsets[:, None, None]).to(encoding.dtype)\n",
    "\n",
    "generator = torch.Generator(device).manual_seed(0)\n",
    "latents_shape = (1, pipe.unet.in_channels, height // 8, width // 8)\n",
    "# take the dtype from this cell's own encoding rather than a variable left over from earlier cells\n",
    "latents = torch.randn(latents_shape, generator=generator, device=device, dtype=encoding.dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b1b9350",
   "metadata": {},
   "outputs": [],
   "source": [
    "generated_images = []\n",
    "# Walk over batch start offsets so that a final, smaller batch is still rendered when\n",
    "# the number of encodings is not a multiple of `batch_size`.\n",
    "for start in range(0, walked_encodings.shape[0], batch_size):\n",
    "    batch_images = pipe.generate_image(\n",
    "        walked_encodings[start:start + batch_size],\n",
    "        height=height,\n",
    "        width=width,\n",
    "        num_inference_steps=num_inference_steps,\n",
    "        latents=latents\n",
    "    )\n",
    "    generated_images.extend(batch_images)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c6401b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "frames_per_second = 8\n",
    "\n",
    "# Save the first frame and append all later frames to build the animation.\n",
    "first_frame, *remaining_frames = generated_images\n",
    "first_frame.save(\n",
    "    \"eiffel-tower-starry-night.gif\",\n",
    "    save_all=True,\n",
    "    append_images=remaining_frames,\n",
    "    duration=1000 // frames_per_second,\n",
    "    loop=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0de5e6fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "IImage(\"eiffel-tower-starry-night.gif\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9311037c",
   "metadata": {},
   "source": [
    "Perhaps unsurprisingly, walking too far from the encoder's latent manifold produces images that look incoherent. Try it for yourself by setting your own prompt, and adjusting step_size to increase or decrease the magnitude of the walk. Note that when the magnitude of the walk gets large, the walk often leads into areas which produce extremely noisy images."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46f92884",
   "metadata": {},
   "source": [
    "## A circular walk through the diffusion noise space for a single prompt\n",
    "\n",
    "Our final experiment is to stick to one prompt and explore the variety of images that the diffusion model can produce from that prompt. We do this by controlling the noise that is used to seed the diffusion process.\n",
    "\n",
    "We create two noise components, `x` and `y`, and do a walk from 0 to 2π, summing our `x` component scaled by the cosine of the walk angle and our `y` component scaled by its sine to produce noise. Using this approach, the end of our walk arrives at the same noise inputs where we began our walk, so we get a \"loopable\" result!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9985b9ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"An oil paintings of cows in a field next to a windmill in Holland\"\n",
    "\n",
    "height, width = 512, 512\n",
    "num_inference_steps = 25\n",
    "\n",
    "walk_steps = 128\n",
    "batch_size = 8\n",
    "\n",
    "encoding = pipe.encode_text(prompt)\n",
    "\n",
    "# Two independent noise samples, drawn from the globally seeded generator.\n",
    "torch.manual_seed(0)\n",
    "latents_shape = (1, pipe.unet.in_channels, height // 8, width // 8)\n",
    "walk_noise_x = torch.randn(latents_shape, device=device, dtype=encoding.dtype)\n",
    "walk_noise_y = torch.randn(latents_shape, device=device, dtype=encoding.dtype)\n",
    "\n",
    "# Angles from 0 to 2*pi; their cosine and sine weight the two noise samples.\n",
    "walk_angles = torch.linspace(0, 2, walk_steps) * math.pi\n",
    "walk_scale_x = torch.cos(walk_angles).to(device, dtype=encoding.dtype)\n",
    "walk_scale_y = torch.sin(walk_angles).to(device, dtype=encoding.dtype)\n",
    "\n",
    "# Broadcast each per-step weight over its noise sample: one start latent per walk step.\n",
    "latents_x = walk_scale_x.view(-1, 1, 1, 1) * walk_noise_x\n",
    "latents_y = walk_scale_y.view(-1, 1, 1, 1) * walk_noise_y\n",
    "latents = latents_x + latents_y\n",
    "\n",
    "# The prompt stays fixed, so one batch of identical encodings is reused for every batch.\n",
    "walked_encodings = encoding.repeat(batch_size, 1, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "533f8d41",
   "metadata": {},
   "outputs": [],
   "source": [
    "generated_images = []\n",
    "# Walk over batch start offsets so that a final, smaller batch of latents is still\n",
    "# rendered when their number is not a multiple of `batch_size`.\n",
    "for start in range(0, latents.shape[0], batch_size):\n",
    "    batch_latents = latents[start:start + batch_size]\n",
    "    batch_images = pipe.generate_image(\n",
    "        # keep the number of encodings equal to the number of latents in this batch\n",
    "        walked_encodings[:batch_latents.shape[0]],\n",
    "        height=height,\n",
    "        width=width,\n",
    "        num_inference_steps=num_inference_steps,\n",
    "        latents=batch_latents\n",
    "    )\n",
    "    generated_images.extend(batch_images)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcf9519d",
   "metadata": {},
   "outputs": [],
   "source": [
    "frames_per_second = 8\n",
    "\n",
    "# Save the first frame and append all later frames to build the animation.\n",
    "first_frame, *remaining_frames = generated_images\n",
    "first_frame.save(\n",
    "    \"cows.gif\",\n",
    "    save_all=True,\n",
    "    append_images=remaining_frames,\n",
    "    duration=1000 // frames_per_second,\n",
    "    loop=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd020c39",
   "metadata": {},
   "outputs": [],
   "source": [
    "IImage(\"cows.gif\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
